(*
  Code to recall that some partial functions (of type 'a -> 'b option)
  can be represented as sorted alists, and derive a fast conversion on
  applications of those functions.
*)
structure alist_treeLib :> alist_treeLib =
struct

open HolKernel Parse boolLib simpLib bossLib

open alist_treeTheory comparisonTheory

(* Syntax *)

(* ALOOKUP is taken from the "alist" theory; every other constant used
   by this library is taken from the "alist_tree" theory. *)
val alookup_tm = prim_mk_const {Thy = "alist", Name = "ALOOKUP"}

(* Fetch the constant named const_name from the "alist_tree" theory. *)
fun mkc const_name = prim_mk_const {Thy = "alist_tree", Name = const_name}

val count_append_tm  = mkc "count_append"
val is_insert_tm     = mkc "is_insert"
val option_choice_tm = mkc "option_choice_f"
val is_lookup_tm     = mkc "is_lookup"
val repr_tm          = mkc "sorted_alist_repr"

(* trivia *)
(* Build a HOL error whose origin structure is this library; fn_name is
   the originating function and msg the message. *)
fun err fn_name msg = mk_HOL_ERR "alist_treeLib" fn_name msg

(* The repr set object.  It bundles the ingredients needed to build and
   use sorted-alist representations:
     R_thm - theorem handed to alist_repr_refl (via MATCH_MP) when a
             representation is started from a short literal alist
     conv  - conversion used to prove the side conditions left over by
             the insert and lookup theorems
     dest  - maps a key term to an ML value
     cmp   - compares two such ML values
     dict  - mutable map from function terms to the theorems recorded
             for them *)
datatype 'a alist_reprs = AList_Reprs of {
    R_thm: thm,
    conv: conv,
    dest: term -> 'a,
    cmp: ('a * 'a) -> order,
    dict: (term, thm) Redblackmap.dict ref
  }

(* Create a repr set from its components.  The dictionary starts out
   empty and orders its keys with Term.compare. *)
fun mk_alist_reprs R_thm conv dest cmp =
  AList_Reprs {
    dict = ref (Redblackmap.mkDict Term.compare),
    R_thm = R_thm,
    conv = conv,
    dest = dest,
    cmp = cmp
  }

(* The terms that currently have an entry in the repr set's dictionary. *)
fun peek_functions_in_rs (AList_Reprs {dict, ...}) =
  map fst (Redblackmap.listItems (!dict))

(* Look tm up in the repr set's dictionary; NONE if it has no entry. *)
fun peek_repr (AList_Reprs {dict, ...}) tm = Redblackmap.peek (!dict, tm)

(* constructing is_insert thms *)

(* Walk a worklist of alist terms and return the first element found
   (or, when is_last is true, the last one).  Empty list literals are
   skipped; a count_append application is replaced on the worklist by
   its arguments after the first, reversed when is_last is true; any
   other term is destructed as a list literal.  Raises Empty when the
   worklist runs out. *)
fun find_key_rec _ [] = raise Empty
  | find_key_rec is_last (t :: rest) =
    if listSyntax.is_nil t then find_key_rec is_last rest
    else let
        val (head, args) = strip_comb t
        (* when searching for the last element, scan right to left *)
        fun orient tms = if is_last then rev tms else tms
      in
        if same_const head count_append_tm
        then find_key_rec is_last (orient (tl args) @ rest)
        else hd (orient (fst (listSyntax.dest_list t)))
      end

(* Key (first pair component) of the first element found in t. *)
fun hd_key t = fst (pairSyntax.dest_pair (find_key_rec false [t]))
(* Key (first pair component) of the last element found in t. *)
fun last_key t = fst (pairSyntax.dest_pair (find_key_rec true [t]))

(* The one-element list literal containing x. *)
fun mk_singleton x = let
    val elems = [x]
    val elem_ty = type_of x
  in listSyntax.mk_list (elems, elem_ty) end

(* Simplification conversion over bool_ss extended with
   count_append_HD_LAST and pairTheory.FST. *)
val simp_count_append = let
    val rewrites = [count_append_HD_LAST, pairTheory.FST]
  in SIMP_CONV bool_ss rewrites end

(* Strip the antecedents off an implication theorem one at a time.
   Each antecedent is first simplified with simp_count_append; it is
   then discharged with TRUTH if it has become T, and otherwise moved
   into the hypotheses by assuming it.  A theorem whose conclusion is
   not an implication is returned as is. *)
fun assume_prems thm =
  if is_imp (concl thm) then let
      val simped = CONV_RULE (RATOR_CONV simp_count_append) thm
      val (ant, _) = dest_imp (concl simped)
      val ant_thm = if ant ~~ T then TRUTH else ASSUME ant
    in
      assume_prems (MP simped ant_thm)
    end
  else thm

(* Instantiate the implication mp_thm so that its antecedent matches the
   conclusion of arg_thm, apply the extra term substitution insts as
   well, and then discharge the antecedent with arg_thm using MP.
   If match_term, INST or MP fails with a HOL_ERR, the failing stage and
   both theorems are printed and an error tagged "do_inst_mp" is raised. *)
fun do_inst_mp insts mp_thm arg_thm = let
    val (prem, _) = dest_imp (concl mp_thm)
    (* report a failure of stage s, then raise this library's error *)
    fun rerr e s = let
        val m = s ^ ": " ^ #message e
      in print ("error in do_inst_mp: " ^ m ^ "\n");
        print_thm mp_thm; print "\n"; print_thm arg_thm; print "\n";
        raise (err "do_inst_mp" m) end
    (* term and type instantiations taking the antecedent to arg_thm *)
    val (more_insts, ty_insts) = match_term prem (concl arg_thm)
      handle HOL_ERR e => rerr e "match_term"
    (* types are instantiated first, then terms *)
    val ty_i_thm = INST_TYPE ty_insts mp_thm
    val ithm = INST (more_insts @ insts) ty_i_thm
      handle HOL_ERR e => rerr e "INST"
  in MP ithm arg_thm handle HOL_ERR e => rerr e "MP" end

(* build_insert dest cmp R k x returns a function that, given an alist
   term t, builds an is_insert theorem for inserting key k with value x
   into t.  Keys are compared in ML by applying cmp to their images
   under dest.  Antecedents of the instantiated theorems are simplified
   and otherwise left as hypotheses (see assume_prems), so the result
   may carry assumptions.
   t must be an empty list literal, a one-element list literal, or an
   application of count_append to exactly three arguments; anything else
   raises an error tagged "build_insert". *)
fun build_insert (dest : term -> 'a) cmp R k x =
  let
    (* the ML image of the inserted key, computed once *)
    val dest_k = dest k
    (* sanity check: the theorem's conclusion must be headed by is_insert *)
    fun chk thm = if same_const is_insert_tm (fst (strip_comb (concl thm)))
      then thm
      else (print "Not an insert_tm:\n"; print_thm thm; print "\n";
        raise (err "build_insert" "check"))
    (* post-process: strip antecedents, then check the shape *)
    val pp = chk o assume_prems
    fun build t = if listSyntax.is_nil t
      (* empty alist *)
      then pp (ISPECL [R, k, x] is_insert_to_empty)
      else if listSyntax.is_cons t then let
        (* literal list: only a single element is supported *)
        val (xs, _) = listSyntax.dest_list t
        val _ = length xs = 1 orelse raise (err "build_insert" "malformed")
        val v = hd xs
      (* compare the new key with the key of the only element *)
      in case cmp (dest_k, dest (fst (pairSyntax.dest_pair v))) of
          EQUAL => pp (ISPECL [R, k, x, v] is_insert_overwrite)
        | GREATER => pp (ISPECL [R, k, x, t] is_insert_far_right)
        | LESS => pp (ISPECL [R, k, x, t] is_insert_far_left)
      end else let
        (* otherwise t must be count_append n l r *)
        val (f, xs) = strip_comb t
        val _ = same_const count_append_tm f
           orelse raise (err "build_insert" "unknown")
        val (n, l, r) = case xs of [n, l, r] => (n, l, r)
           | _ => raise (err "build_insert" "num args")
        (* substitution of v for the variable named nm of v's type *)
        fun vsub nm v = [mk_var (nm, type_of v) |-> v]
      (* key not below the first key of r: recurse into r *)
      in if not (cmp (dest_k, dest (hd_key r)) = LESS)
        then do_inst_mp (vsub "l" l) (SPEC n is_insert_r) (build r)
        (* key above the last key of l (and below the first key of r):
           use is_insert_centre *)
        else if cmp (dest_k, dest (last_key l)) = GREATER
        then ISPECL [R, n, k, x] is_insert_centre
            |> INST (vsub "l" l @ vsub "r" r) |> pp
        (* otherwise recurse into l *)
        else do_inst_mp (vsub "r" r) (SPEC n is_insert_l) (build l)
      end
  in build end

(* Discharge the antecedent of the implication theorem thm by applying
   conv to it and then using MP with TRUTH.  If MP fails with a HOL_ERR,
   the original antecedent and what conv reduced it to are printed and
   the same error is re-raised. *)
fun prove_assum_by_conv conv thm = let
    val orig_ant = fst (dest_imp (concl thm))
    val reduced = CONV_RULE (RATOR_CONV (RAND_CONV conv)) thm
    fun report e =
      (print "Failed to prove assum by conv:\n";
       print_term orig_ant;
       print "\n  -- reduced to:\n";
       print_term (fst (dest_imp (concl reduced)));
       raise HOL_ERR e)
  in
    MP reduced TRUTH handle HOL_ERR e => report e
  end

(* balancing count_append trees *)
(* Depth of an alist term: the numeral stored as the first argument of a
   count_append application, or 1 for a non-empty list literal.  Any
   other term raises an error tagged "get_depth". *)
fun get_depth tm = let
    val (head, args) = strip_comb tm
  in
    if same_const head count_append_tm then numSyntax.int_of_term (hd args)
    else if not (listSyntax.is_cons tm) then raise (err "get_depth" "unknown")
    else 1
  end

(* Conversion that rebalances a count_append tree and records its depth.
     iter - number of times this node has already been retried; more
            than 1000 is reported as a loop and raises an error
     bias - "N" for no preference, "L" or "R" to also rebalance when
            the two subtrees differ in depth by one in the corresponding
            direction (see the computation of reb below)
     tm   - the term to convert
   Raises UNCHANGED if tm is not a count_append application, or if bias
   is "N" and the count argument is not ARB (presumably meaning the node
   has already been processed - TODO confirm). *)
fun balance iter bias tm = if iter > 1000 then
    (print "error: looping balance\n"; print_term tm;
        raise (err "balance" "looping"))
  else let
    val (f, xs) = strip_comb tm
    val _ = same_const f count_append_tm orelse raise UNCHANGED
    val _ = is_arb (hd xs) orelse bias <> "N" orelse raise UNCHANGED
    (* balance the right subtree, then the left one, without bias *)
    val step_conv = (RAND_CONV (balance 0 "N"))
        THENC (RATOR_CONV (RAND_CONV (balance 0 "N")))
    val thm = QCONV step_conv tm
    val tm = rhs (concl thm)
    (* depths of the left and right subtrees after the first pass *)
    val l_sz = get_depth (rand (rator tm))
    val r_sz = get_depth (rand tm)
    (* choose which rewrite to apply: "R" when the left side is deeper
       (by more than one, or by any amount under bias "R"), "L" in the
       mirrored situation, "N" for none *)
    val reb = if l_sz > (r_sz + 1) orelse (bias = "R" andalso l_sz > r_sz)
        then "R"
        else if r_sz > (l_sz + 1) orelse (bias = "L" andalso r_sz > l_sz)
        then "L" else "N"
    (* first rebalance the deeper subtree with the opposite bias ... *)
    val conv1 = if reb = "R" then RATOR_CONV (RAND_CONV (balance 0 "L"))
        else if reb = "L" then RAND_CONV (balance 0 "R") else ALL_CONV
    (* ... then rewrite this node with balance_r or balance_l *)
    val conv2 = if reb = "R" then REWR_CONV balance_r
        else if reb = "L" then REWR_CONV balance_l else ALL_CONV
    (* the subtrees have changed, so run step_conv on them again *)
    val thm = CONV_RULE (QCONV (RHS_CONV (conv1 THENC conv2 THENC step_conv))) thm
    val tm = rhs (concl thm)
    val l_sz = get_depth (rand (rator tm))
    val r_sz = get_depth (rand tm)
    (* the depth to record at this node *)
    val sz = numSyntax.term_of_int (1 + Int.max (l_sz, r_sz))
    val set = REWR_CONV (set_count |> SPEC sz)
    (* record the depth if the subtrees are within one of each other,
       otherwise try this node again *)
    val final = if Int.abs (l_sz - r_sz) < 2 then set else balance (iter + 1) "N"
  in CONV_RULE (RHS_CONV final) thm end

(* Prove an is_insert theorem for inserting (k, x) into al.  The
   hypotheses left by build_insert are turned into antecedents and
   proved one by one with conv; finally the rand of the conclusion is
   rebalanced. *)
fun prove_insert R conv dest cmp k x al = let
    fun discharge th =
      if is_imp (concl th) then discharge (prove_assum_by_conv conv th)
      else th
    val raw_thm = build_insert dest cmp R k x al
    val closed_thm = discharge (DISCH_ALL raw_thm)
  in
    CONV_RULE (RAND_CONV (balance 0 "N")) closed_thm
  end

(* making repr theorems *)

(* Extend the representation theorem prev_repr by the key/value pair
   term k_x.  The relation and the alist are read off the first two of
   the three arguments in prev_repr's conclusion, an is_insert theorem
   is proved for them, and the two are combined with repr_insert. *)
fun mk_insert_repr (AList_Reprs {conv, dest, cmp, ...}) prev_repr k_x = let
    val (k, x) = pairSyntax.dest_pair k_x
    val (R, al) = case snd (strip_comb (concl prev_repr)) of
        [R, al, _] => (R, al)
      | _ => raise (err "mk_insert_repr" "unexpected")
    val insert_thm = prove_insert R conv dest cmp k x al
  in
    MATCH_MP repr_insert (CONJ prev_repr insert_thm)
  end

(* Recognise ALOOKUP applied to exactly one argument that is a list
   literal of length at most one:
     ALOOKUP []    gives SOME NONE
     ALOOKUP [y]   gives SOME (SOME y)
     anything else gives NONE *)
fun dest_alookup_single tm =
  case strip_comb tm of
    (head, [al]) =>
      if not (same_const head alookup_tm) then NONE
      else if listSyntax.is_nil al then SOME NONE
      else (case total listSyntax.dest_cons al of
          NONE => NONE
        | SOME (y, ys) =>
            if listSyntax.is_nil ys then SOME (SOME y) else NONE)
  | _ => NONE

(* mk_repr_step rs tm performs one step towards a representation
   theorem for tm.  The result is either a representation theorem or an
   equation rewriting tm; mk_repr follows equations until a
   non-equation is reached.  The cases, tried in order:
     - tm is ALOOKUP of an empty or one-element list literal: start
       from alist_repr_refl;
     - tm is option_choice_f whose first argument is ALOOKUP of a
       one-element list literal: represent the last argument, then
       insert that element;
     - tm is option_choice_f whose first argument has a dictionary
       entry: represent that argument, re-express tm over the resulting
       alist and combine with alist_repr_choice_trans_left;
     - otherwise: simplify tm with the listed rewrite theorems, raising
       an error tagged "mk_repr_step" if that makes no progress.
   NOTE(review): hd xs is evaluated whenever the head is
   option_choice_f, so that constant applied to no arguments would
   raise Empty - confirm callers never pass it. *)
fun mk_repr_step rs tm = let
    val (AList_Reprs inn_rs) = rs
    val (f, xs) = strip_comb tm
    (* true iff dest_alookup_single recognises tm *)
    val is_short =
        not (option_eq (option_eq aconv) (dest_alookup_single tm) NONE)
    val is_merge = same_const option_choice_tm f
    (* a merge whose first argument is already in the dictionary *)
    val is_repr_merge = is_merge andalso (case peek_repr rs (hd xs) of
        SOME _ => true | NONE => false)
    (* a merge whose first argument is ALOOKUP [insert_tm]; the T is
       only a placeholder for the unused second component *)
    val (is_insert, insert_tm) = if not is_merge then (false, T)
        else case dest_alookup_single (hd xs) of
            SOME (SOME t) => (true, t) | _ => (false, T)
  in if is_short
    then MATCH_MP (ISPEC (hd xs) alist_repr_refl) (#R_thm inn_rs)
        |> prove_assum_by_conv (SIMP_CONV list_ss [sortingTheory.SORTED_DEF])
    else if is_insert
    then mk_insert_repr rs (mk_repr rs (rand tm)) insert_tm
    else if is_repr_merge
    then let
        val l_repr_thm = mk_repr rs (hd xs)
        (* the alist is the second-to-last argument of the conclusion *)
        val l_repr_al = rand (rator (concl l_repr_thm))
        val look = mk_icomb (alookup_tm, l_repr_al)
        (* tm with its first argument replaced by ALOOKUP of that alist *)
        val half_repr = list_mk_icomb (option_choice_tm, [look, List.last xs])
        val next_repr = mk_repr rs half_repr
      in MATCH_MP alist_repr_choice_trans_left (CONJ l_repr_thm next_repr) end
    else CHANGED_CONV (SIMP_CONV bool_ss [alookup_to_option_choice,
                option_choice_f_assoc, alookup_empty_option_choice_f,
                count_append_def, alookup_append_option_choice_f,
                empty_is_ALOOKUP]) tm
        handle HOL_ERR _ => raise err "mk_repr_step"
            ("no progress from SIMP_CONV: " ^ Parse.term_to_string tm)
  end
(* Like mk_repr_step, but a theorem already recorded for tm in the
   dictionary is returned directly. *)
and mk_repr_known_step rs tm =
  case peek_repr rs tm of
    SOME thm => thm
  | NONE => mk_repr_step rs tm
(* Build a theorem for tm by repeated steps.  While a step yields an
   equation, continue from its right-hand side and then rewrite the rand
   of the resulting conclusion back using the symmetric equation. *)
and mk_repr rs tm = let
    val thm = mk_repr_known_step rs tm
  in if is_eq (concl thm)
    then mk_repr rs (rhs (concl thm))
      |> CONV_RULE (RAND_CONV (REWR_CONV (SYM thm)))
    else thm
  end

(* Record a theorem for the left-hand side of the equation thm in the
   dictionary of rs.  If the right-hand side already has an entry, the
   stored theorem is thm chained with that entry when the entry is an
   equation, and thm itself otherwise.  If it has no entry, a theorem is
   built for the right-hand side with mk_repr and the rand of its
   conclusion is rewritten to the left-hand side. *)
fun add_alist_repr rs thm = let
    val AList_Reprs {dict, ...} = rs
    val (lhs_tm, rhs_tm) = dest_eq (concl thm)
    fun retarget repr = CONV_RULE (RAND_CONV (REWR_CONV (SYM thm))) repr
    val stored = case peek_repr rs rhs_tm of
        NONE => retarget (mk_repr rs rhs_tm)
      | SOME known =>
          if is_eq (concl known) then TRANS thm known else thm
  in
    dict := Redblackmap.insert (!dict, lhs_tm, stored)
  end

(* Apply f to v, print how long the call took (labelled with msg), and
   return the result of the call. *)
fun timeit msg f v = let
    val t0 = Portable.timestamp ()
    val result = f v
    val elapsed = Time.- (Portable.timestamp (), t0)
    val report = "Time to " ^ msg ^ ": " ^ Portable.time_to_string elapsed ^ "\n"
  in
    print report; result
  end

(* testing *)

(* A repr set for testing with numeral keys: keys are destructed with
   numSyntax.int_of_term and compared with Int.compare, side conditions
   are proved with EVAL, and R_thm is obtained from two theorems of the
   "comparison" theory. *)
fun test_rs () = let
    val irrefl_trans_thm = DB.fetch "comparison" "good_cmp_Less_irrefl_trans"
    val num_good_thm = DB.fetch "comparison" "num_cmp_good"
    val R_thm = MATCH_MP irrefl_trans_thm num_good_thm
  in
    mk_alist_reprs R_thm EVAL numSyntax.int_of_term Int.compare
  end

(* Build the term ALOOKUP [...] with one entry per integer i in ns: the
   key is the numeral for (i * 157) mod 1000 and the value is the
   numeral for i. *)
fun test_mk_alookup ns = let
    fun key_of i = (i * 157) mod 1000
    fun entry i = pairSyntax.mk_pair
        (numSyntax.term_of_int (key_of i), numSyntax.term_of_int i)
    (* entry 0 is built only to obtain the element type *)
    val alist = listSyntax.mk_list (map entry ns, type_of (entry 0))
  in
    mk_icomb (alookup_tm, alist)
  end

(* Build a representation theorem for the option_choice_f merge of a
   200-entry test alist lookup with a 3-entry one. *)
fun test_200 rs = let
    val big = test_mk_alookup (upto 1 200)
    val small = test_mk_alookup [1, 4, 3]
  in
    mk_repr rs (list_mk_icomb (option_choice_tm, [big, small]))
  end

(*
val rs = test_rs ()
val thm_200 = timeit "build repr" test_200 rs
*)

(* proving and using is_lookup thms *)

(* build_lookup dest cmp R k returns a function that, given an alist
   term t, builds an is_lookup theorem for key k in t.  Keys are
   compared in ML by applying cmp to their images under dest.
   Antecedents of the instantiated theorems are simplified and otherwise
   left as hypotheses (see assume_prems), so the result may carry
   assumptions.
   t must be an empty list literal, a one-element list literal, or an
   application of count_append to exactly three arguments; anything else
   raises an error tagged "build_lookup". *)
fun build_lookup (dest : term -> 'a) cmp R k =
  let
    (* the ML image of the looked-up key, computed once *)
    val dest_k = dest k
    (* sanity check: the theorem's conclusion must be headed by is_lookup *)
    fun chk thm = if same_const is_lookup_tm (fst (strip_comb (concl thm)))
      then thm
      else (print "Not a lookup_tm:\n"; print_thm thm; print "\n";
        raise (err "build_lookup" "check"))
    (* post-process: strip antecedents, then check the shape *)
    val pp = chk o assume_prems
    fun build t = if listSyntax.is_nil t
      (* empty alist *)
      then pp (ISPECL [R, k, t] is_lookup_empty)
      else if listSyntax.is_cons t then let
        (* literal list: only a single element is supported *)
        val (xs, _) = listSyntax.dest_list t
        (* this error used to be tagged "build_insert" *)
        val _ = length xs = 1 orelse raise (err "build_lookup" "malformed")
        val (k', v) = pairSyntax.dest_pair (hd xs)
      (* compare the looked-up key with the key of the only element *)
      in case cmp (dest_k, dest k') of
          EQUAL => pp (ISPECL [R, k, k', v] is_lookup_hit)
        | GREATER => pp (ISPECL [R, k, k', v] is_lookup_far_right)
        | LESS => pp (ISPECL [R, k, k', v] is_lookup_far_left)
      end else let
        (* otherwise t must be count_append n l r *)
        val (f, xs) = strip_comb t
        val _ = same_const count_append_tm f
           orelse raise (err "build_lookup" "unknown")
        val (n, l, r) = case xs of [n, l, r] => (n, l, r)
           | _ => raise (err "build_lookup" "num args")
        (* substitution of v for the variable named nm of v's type *)
        fun vsub nm v = [mk_var (nm, type_of v) |-> v]
      (* key not below the first key of r: recurse into r *)
      in if not (cmp (dest_k, dest (hd_key r)) = LESS)
        then do_inst_mp (vsub "l" l) (SPEC n is_lookup_r) (build r)
        (* key above the last key of l (and below the first key of r):
           use is_lookup_centre *)
        else if cmp (dest_k, dest (last_key l)) = GREATER
        then pp (ISPECL [R, n, l, r, k] is_lookup_centre)
        (* otherwise recurse into l *)
        else do_inst_mp (vsub "r" r) (SPEC n is_lookup_l) (build l)
      end
  in build end

(* Prove an is_lookup theorem for key k in al.  The hypotheses left by
   build_lookup are turned into antecedents and proved one by one with
   conv. *)
fun prove_lookup R conv dest cmp k al = let
    fun discharge th =
      if is_imp (concl th) then discharge (prove_assum_by_conv conv th)
      else th
    val raw_thm = build_lookup dest cmp R k al
  in
    discharge (DISCH_ALL raw_thm)
  end

(* Given a representation theorem, whose conclusion must be
   sorted_alist_repr applied to exactly three arguments, prove an
   is_lookup theorem for key k in the represented alist and combine the
   two with lookup_repr.  Raises an error tagged "repr_prove_lookup" if
   the conclusion has a different shape. *)
fun repr_prove_lookup conv dest cmp repr_thm k = let
    val (head, args) = strip_comb (concl repr_thm)
    val _ = same_const head repr_tm orelse
        raise (err "repr_prove_lookup" "unexpected")
    val (R, al) = case args of
        [R, al, _] => (R, al)
      | _ => raise (err "repr_prove_lookup" "num args")
    val lookup_thm = prove_lookup R conv dest cmp k al
  in
    MATCH_MP lookup_repr (CONJ repr_thm lookup_thm)
  end

(* Conversion for an application f x where f has an entry in the
   dictionary of rs.  If the entry is an equation, f is rewritten with
   it and the conversion is tried again on the result; otherwise the
   entry is used as a representation theorem to prove a lookup of x.
   Raises UNCHANGED if tm is not an application or f has no entry. *)
fun reprs_conv rs tm = let
    val AList_Reprs {conv, dest, cmp, ...} = rs
    val (f, x) = dest_comb tm handle HOL_ERR _ => raise UNCHANGED
  in
    case peek_repr rs f of
      NONE => raise UNCHANGED
    | SOME repr_thm =>
        if is_eq (concl repr_thm)
        then (RATOR_CONV (REWR_CONV repr_thm) THENC reprs_conv rs) tm
        else repr_prove_lookup conv dest cmp repr_thm x
  end

(* Run reprs_conv on f applied to the numeral for i. *)
fun extract_test f rs i = let
    val application = mk_comb (f, numSyntax.term_of_int i)
  in reprs_conv rs application end

(* Timing test: define a constant "f" equal to a 300-entry test alist
   lookup, register the definition in rs, then run reprs_conv on f
   applied to each of the numerals 1 to 1000.  Returns the list of
   resulting theorems. *)
fun extract_test_1000 rs = let
    val alookup = test_mk_alookup (upto 1 300)
    val f_var = mk_var ("f", type_of alookup)
    val f_def = new_definition ("f", mk_eq (f_var, alookup))
    (* add_alist_repr files the entry under the left-hand side of f_def,
       which is the newly defined constant.  The lookups must use that
       same term: the variable f_var is a different dictionary key and
       would make reprs_conv raise UNCHANGED. *)
    val f = lhs (concl f_def)
    val _ = timeit "add def" (add_alist_repr rs) f_def
  in
    timeit "map extract" (map (extract_test f rs)) (upto 1 1000)
  end

end
